import optuna
import time
import logging
from optuna.trial import Trial
from hypersense.optimizer.base_optimizer import BaseOptimizer
from typing import Dict, Any, List, Optional, Union, Tuple


class OptunaOptimizer(BaseOptimizer):
    """Hyperparameter optimizer backed by an Optuna study.

    The search space (``self.space``, set up by ``BaseOptimizer``) is either a dict
    mapping parameter names to ``optuna.distributions`` objects, or a define-by-run
    callable ``space(trial) -> dict`` that draws values through the trial.
    """

    def __init__(self, sampler: Optional[optuna.samplers.BaseSampler] = None, **kwargs):
        """Create the optimizer and its underlying Optuna study.

        Args:
            sampler: Optuna sampler to use. Defaults to a ``TPESampler`` seeded with
                ``self.seed`` and configured with ``n_startup_trials=0``.
            **kwargs: Forwarded to ``BaseOptimizer.__init__``.

        Raises:
            TypeError: If ``sampler`` is given but is not an Optuna sampler.
            ValueError: If ``metric``/``mode`` are inconsistent, ``mode`` is not
                "min"/"max", or the warm-start data is malformed.
        """
        super().__init__(**kwargs)
        # Validate the caller's sampler first; an explicit ``is None`` test avoids
        # relying on the truthiness of an arbitrary sampler object.
        if sampler is not None and not isinstance(sampler, optuna.samplers.BaseSampler):
            raise TypeError("sampler must be an instance of optuna.samplers.BaseSampler")
        self.sampler = (
            sampler if sampler is not None else optuna.samplers.TPESampler(seed=self.seed, n_startup_trials=0)
        )
        self._validate_multi_objective()
        self.is_define_by_run = callable(self.space)
        self._init_study()

    @staticmethod
    def _mode_to_direction(mode: str) -> str:
        """Map a ``mode`` value to an Optuna study direction.

        "min"/"minimize" map to "minimize" and "max"/"maximize" to "maximize".
        Any other value raises instead of silently defaulting to maximization.
        """
        if mode in ("min", "minimize"):
            return "minimize"
        if mode in ("max", "maximize"):
            return "maximize"
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}.")

    def _init_study(self):
        """Create ``self.study`` and seed it with any warm-start points.

        Points in ``points_to_evaluate`` that have a matching entry (by position) in
        ``evaluated_rewards`` are added as already-completed trials. Any remaining
        points are enqueued so the study evaluates them before sampling new ones.

        Raises:
            ValueError: If there are more rewards than points, or a rewarded point
                cannot be mapped onto the search space.
        """
        if isinstance(self.mode, str):
            # Single objective
            self.study = optuna.create_study(direction=self._mode_to_direction(self.mode), sampler=self.sampler)
        else:
            # Multi-objective: one direction per entry of ``mode``.
            self.study = optuna.create_study(
                directions=[self._mode_to_direction(m) for m in self.mode],
                sampler=self.sampler,
            )

        # ``or []`` tolerates these attributes being left as None.
        points = list(self.points_to_evaluate or [])
        rewards = list(self.evaluated_rewards or [])
        if len(rewards) > len(points):
            raise ValueError("evaluated_rewards has more entries than points_to_evaluate.")

        # Warm-start: points with a known reward become completed trials.
        for config, reward in zip(points, rewards):
            self.study.add_trial(
                optuna.trial.create_trial(
                    params=config,
                    distributions=self._convert_distribution(config),
                    # Multi-objective rewards may arrive as a list or a tuple.
                    values=list(reward) if isinstance(reward, (list, tuple)) else [reward],
                )
            )
        # Points without a reward still need evaluating: queue them to run first.
        for config in points[len(rewards):]:
            self.study.enqueue_trial(config)

    def _convert_distribution(self, config: Dict[str, Any]) -> Dict[str, optuna.distributions.BaseDistribution]:
        """Return the Optuna distributions for exactly the parameters in ``config``.

        ``optuna.trial.create_trial`` requires the distributions to cover the same
        parameter names as the trial's params, so the full space is not returned.

        Raises:
            ValueError: If the space is not a dict (a define-by-run callable cannot
                describe a finished trial) or ``config`` names an unknown parameter.
        """
        if not isinstance(self.space, dict):
            raise ValueError("Space must be dict[optuna.distribution] for warm-start.")
        unknown = [name for name in config if name not in self.space]
        if unknown:
            raise ValueError(f"Warm-start config has parameters not in the search space: {unknown}")
        # Values of ``self.space`` are assumed to be optuna distributions.
        return {name: self.space[name] for name in config}

    def _suggest_config(self, trial: Trial) -> Dict[str, Any]:
        """Draw one configuration for ``trial`` from the search space.

        A define-by-run space is called with the trial; a falsy return becomes
        ``{}``. A dict space is sampled one distribution at a time.
        """
        if self.is_define_by_run:
            return self.space(trial) or {}
        # NOTE: ``Trial._suggest`` is private Optuna API; it accepts any
        # distribution object directly and honours enqueued (fixed) params.
        return {name: trial._suggest(name, dist) for name, dist in self.space.items()}

    def _validate_multi_objective(self):
        """Ensure list-valued ``metric`` and ``mode`` have the same length."""
        if isinstance(self.metric, list) and isinstance(self.mode, list):
            if len(self.metric) != len(self.mode):
                raise ValueError("Length of `metric` and `mode` must be the same for multi-objective optimization.")

    def optimize(self):
        """
        Run the optimization process and record (config, result, elapsed_time) for each trial.

        The run ends after ``max_trials`` trials, after ``max_time`` seconds, when
        ``early_stopping_fn(trial_history)`` returns True (the triggering trial's
        result is kept), or on KeyboardInterrupt (trials finished so far are kept).
        ``self.elapsed_time`` is always updated, even if the run raises.

        Returns:
            List of tuples: (config, result, elapsed_time)

        Raises:
            ValueError: If no callable objective function is configured.
        """
        if not callable(self.objective_fn):
            raise ValueError("Objective function must be provided for optimization.")
        optuna.logging.set_verbosity(optuna.logging.INFO if self.verbose else optuna.logging.WARNING)
        self.trial_history = []
        start = time.perf_counter()

        def _objective(trial: Trial):
            start_trial = time.perf_counter()
            config = self._suggest_config(trial)
            result = self.objective_fn(config)
            elapsed = time.perf_counter() - start_trial
            self.trial_history.append((config, result, elapsed))
            if self.early_stopping_fn and self.early_stopping_fn(self.trial_history):
                # Study-level early stopping: keep this trial's result and end the
                # optimize loop. Raising TrialPruned would discard only this trial
                # and keep sampling new ones.
                self.study.stop()
            return result

        try:
            self.study.optimize(
                _objective,
                n_trials=self.max_trials,
                timeout=self.max_time,
                show_progress_bar=False,
            )
        except KeyboardInterrupt:
            # Deliberately best-effort: keep whatever trials completed.
            logging.getLogger(__name__).warning("Optimization interrupted.")
        finally:
            self.elapsed_time = time.perf_counter() - start
        return self.trial_history

    def get_best_config(self, include_score: bool = False) -> Dict[str, Any]:
        """Return the best configuration found so far.

        For a multi-objective study, ``study.best_trial`` is unavailable, so the
        first trial on the Pareto front (``study.best_trials``) is used as a
        representative.

        Args:
            include_score: If True, return a dict with ``params``, ``score`` (a single
                value for single-objective, a list for multi-objective) and
                ``elapsed_time`` (rounded to 4 decimals, or None if never recorded).

        Returns:
            The best params dict, or the summary dict when ``include_score`` is True.

        Raises:
            ValueError: If the study has no completed trials.
        """
        if len(self.study.directions) > 1:
            pareto_front = self.study.best_trials
            if not pareto_front:
                raise ValueError("No completed trials are available yet.")
            trial = pareto_front[0]
            score = trial.values
        else:
            trial = self.study.best_trial
            score = trial.value
        config = trial.params

        if include_score:
            elapsed = getattr(self, "elapsed_time", None)
            return {
                "params": config,
                "score": score,
                "elapsed_time": (round(elapsed, 4) if elapsed is not None else None),
            }

        return config
